{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7103ac7a",
   "metadata": {},
   "source": [
    "# Benchmarking Forecasters\n",
    "\n",
    "The `benchmarking` modules allows you to easily orchestrate benchmarking experiments in which you want to\n",
    "compare the performance of one or more algorithms over one or more datasets and benchmark configurations.\n",
    "\n",
    "Benchmarking as an endeavour in general is very easy to get wrong, giving false conclusions about estimator\n",
    "performance - see this [2022 research from Princeton](https://reproducible.cs.princeton.edu/)\n",
    "for numerous examples of such mistakes in peer reviewed academic papers as evidence of this.\n",
    "\n",
    "`sktime`'s `benchmarking` module is designed to provide benchmarking functionality while enforcing best\n",
    "practices and structure to help users avoid making mistakes (such as data leakage, etc.) which invalidate\n",
    "their results. The `benchmarking` module is designed for easy usage in mind, as such it interfaces\n",
    "directly with `sktime` objects and classes. Previously developed estimator should be usable as they are without\n",
    "alterations.\n",
    "\n",
    "This notebook demonstrates usage of the `benchmarking` module."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c1bd1038",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sktime.benchmarking.forecasting import ForecastingBenchmark\n",
    "from sktime.datasets import load_airline\n",
    "from sktime.forecasting.naive import NaiveForecaster\n",
    "from sktime.performance_metrics.forecasting import MeanSquaredPercentageError\n",
    "from sktime.split import ExpandingWindowSplitter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d19b19c7",
   "metadata": {},
   "source": [
    "### Instantiate an instance of a benchmark class\n",
    "In this example we are comparing forecasting estimators."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5ed25ac5",
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmark = ForecastingBenchmark()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9bbd5032",
   "metadata": {},
   "source": [
    "### Add competing estimators\n",
    "We add different competing estimators to the benchmark instance. All added estimators will \n",
    "be automatically ran through each added benchmark tasks, and their results compiled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d9122963",
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmark.add_estimator(\n",
    "    estimator=NaiveForecaster(strategy=\"mean\", sp=12),\n",
    "    estimator_id=\"NaiveForecaster-mean-v1\",\n",
    ")\n",
    "benchmark.add_estimator(\n",
    "    estimator=NaiveForecaster(strategy=\"last\", sp=12),\n",
    "    estimator_id=\"NaiveForecaster-last-v1\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30056814",
   "metadata": {},
   "source": [
    "### Add benchmarking tasks\n",
    "These are the prediction/validation tasks over which every estimator will be tested and their results compiled.\n",
    "\n",
    "The exact arguments for a benchmarking task depend on the whether the objective is forecasting, classification, etc.,\n",
    "but generally they are similar. The following are the required arguments for defining a forecasting benchmark task."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3368d276",
   "metadata": {},
   "source": [
    "#### Specify cross-validation split regime(s)\n",
    "Define cross-validation split regimes, using standard `sktime` objects. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "812bd976",
   "metadata": {},
   "outputs": [],
   "source": [
    "cv_splitter = ExpandingWindowSplitter(\n",
    "    initial_window=24,\n",
    "    step_length=12,\n",
    "    fh=12,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa9e2a07",
   "metadata": {},
   "source": [
    "#### Specify performance metric(s)\n",
    "Define performance metrics on which to compare estimators, using standard `sktime` objects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8dde063e",
   "metadata": {},
   "outputs": [],
   "source": [
    "scorers = [MeanSquaredPercentageError()]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "374b66d3",
   "metadata": {},
   "source": [
    "#### Specify dataset loaders\n",
    "Define dataset loaders, which are callables (functions) which should return a dataset. Generally\n",
    "this is a callable which returns a dataframe containing the entire dataset. One can use\n",
    "the `sktime` defined datasets, or define their own. Something as simple as the following\n",
    "example will suffice: \n",
    "```python\n",
    "def my_dataset_loader():\n",
    "    return pd.read_csv(\"path/to/data.csv\")\n",
    "```\n",
    "The datasets will be loaded when running the benchmarking tasks, ran through the cross-validation\n",
    "regime(s) and subsequently the estimators will be tested over the dataset splits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d2f0f4f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_loaders = [load_airline]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8499be64",
   "metadata": {},
   "source": [
    "#### Add tasks to the benchmark instance\n",
    "Use the previously defined objects to add tasks to the benchmark instance.\n",
    "Optionally use loops etc. to easily setup multiple benchmark tasks reusing arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "35ae72d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "for dataset_loader in dataset_loaders:\n",
    "    benchmark.add_task(\n",
    "        dataset_loader,\n",
    "        cv_splitter,\n",
    "        scorers,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "401deb16",
   "metadata": {},
   "source": [
    "### Run all task-estimator combinations and store results\n",
    "\n",
    "Note that `run` won't rerun tasks it already has results for, so adding a new\n",
    "estimator and running `run` again will only run tasks for that new estimator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5f75a779",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>validation_id</th>\n",
       "      <td>[dataset=load_airline]_[cv_splitter=ExpandingW...</td>\n",
       "      <td>[dataset=load_airline]_[cv_splitter=ExpandingW...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>model_id</th>\n",
       "      <td>NaiveForecaster-last-v1</td>\n",
       "      <td>NaiveForecaster-mean-v1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>runtime_secs</th>\n",
       "      <td>0.061472</td>\n",
       "      <td>0.081733</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MeanSquaredPercentageError_fold_0_test</th>\n",
       "      <td>0.024532</td>\n",
       "      <td>0.049681</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MeanSquaredPercentageError_fold_1_test</th>\n",
       "      <td>0.020831</td>\n",
       "      <td>0.0737</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MeanSquaredPercentageError_fold_2_test</th>\n",
       "      <td>0.001213</td>\n",
       "      <td>0.05352</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MeanSquaredPercentageError_fold_3_test</th>\n",
       "      <td>0.01495</td>\n",
       "      <td>0.081063</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MeanSquaredPercentageError_fold_4_test</th>\n",
       "      <td>0.031067</td>\n",
       "      <td>0.138163</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MeanSquaredPercentageError_fold_5_test</th>\n",
       "      <td>0.008373</td>\n",
       "      <td>0.145125</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MeanSquaredPercentageError_fold_6_test</th>\n",
       "      <td>0.007972</td>\n",
       "      <td>0.154337</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MeanSquaredPercentageError_fold_7_test</th>\n",
       "      <td>0.000009</td>\n",
       "      <td>0.123298</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MeanSquaredPercentageError_fold_8_test</th>\n",
       "      <td>0.028191</td>\n",
       "      <td>0.185644</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MeanSquaredPercentageError_fold_9_test</th>\n",
       "      <td>0.003906</td>\n",
       "      <td>0.184654</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MeanSquaredPercentageError_mean</th>\n",
       "      <td>0.014104</td>\n",
       "      <td>0.118918</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MeanSquaredPercentageError_std</th>\n",
       "      <td>0.011451</td>\n",
       "      <td>0.051265</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                        0  \\\n",
       "validation_id                           [dataset=load_airline]_[cv_splitter=ExpandingW...   \n",
       "model_id                                                          NaiveForecaster-last-v1   \n",
       "runtime_secs                                                                     0.061472   \n",
       "MeanSquaredPercentageError_fold_0_test                                           0.024532   \n",
       "MeanSquaredPercentageError_fold_1_test                                           0.020831   \n",
       "MeanSquaredPercentageError_fold_2_test                                           0.001213   \n",
       "MeanSquaredPercentageError_fold_3_test                                            0.01495   \n",
       "MeanSquaredPercentageError_fold_4_test                                           0.031067   \n",
       "MeanSquaredPercentageError_fold_5_test                                           0.008373   \n",
       "MeanSquaredPercentageError_fold_6_test                                           0.007972   \n",
       "MeanSquaredPercentageError_fold_7_test                                           0.000009   \n",
       "MeanSquaredPercentageError_fold_8_test                                           0.028191   \n",
       "MeanSquaredPercentageError_fold_9_test                                           0.003906   \n",
       "MeanSquaredPercentageError_mean                                                  0.014104   \n",
       "MeanSquaredPercentageError_std                                                   0.011451   \n",
       "\n",
       "                                                                                        1  \n",
       "validation_id                           [dataset=load_airline]_[cv_splitter=ExpandingW...  \n",
       "model_id                                                          NaiveForecaster-mean-v1  \n",
       "runtime_secs                                                                     0.081733  \n",
       "MeanSquaredPercentageError_fold_0_test                                           0.049681  \n",
       "MeanSquaredPercentageError_fold_1_test                                             0.0737  \n",
       "MeanSquaredPercentageError_fold_2_test                                            0.05352  \n",
       "MeanSquaredPercentageError_fold_3_test                                           0.081063  \n",
       "MeanSquaredPercentageError_fold_4_test                                           0.138163  \n",
       "MeanSquaredPercentageError_fold_5_test                                           0.145125  \n",
       "MeanSquaredPercentageError_fold_6_test                                           0.154337  \n",
       "MeanSquaredPercentageError_fold_7_test                                           0.123298  \n",
       "MeanSquaredPercentageError_fold_8_test                                           0.185644  \n",
       "MeanSquaredPercentageError_fold_9_test                                           0.184654  \n",
       "MeanSquaredPercentageError_mean                                                  0.118918  \n",
       "MeanSquaredPercentageError_std                                                   0.051265  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_df = benchmark.run(\"./forecasting_results.csv\")\n",
    "results_df.T"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
